# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# This script is useful for inspecting the ops in graphs generated by Ruby.
# It reads a graph definition in pb (binary) or pbtxt (text) format, generated
# by Ruby or by Python, and converts it between the binary and the
# human-readable text format.

from absl import flags
from absl import app
from absl import logging
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile


def pbtxt_to_graphdef(filename,
                      output_dir='pbtxt/',
                      output_name='protobuf.pb'):
  """Convert a text-format GraphDef (.pbtxt) into a binary GraphDef (.pb).

  The parsed graph is also imported into the current default graph, which
  acts as a validity check: the import raises if the graph is malformed
  (e.g. it references ops not registered in this process).

  Args:
    filename: Path of the text-format GraphDef to read. It is opened through
      TensorFlow's gfile, so any filesystem gfile supports can be used.
    output_dir: Directory the binary graph is written to.
    output_name: File name of the binary graph inside `output_dir`.

  Returns:
    The path of the written file, as returned by `tf.train.write_graph`.
  """
  # Use gfile (like graphdef_to_pbtxt) for consistent path handling, and
  # close the handle as soon as the content has been read.
  with gfile.GFile(filename, 'r') as f:
    file_content = f.read()
  graph_def = tf.GraphDef()
  text_format.Merge(file_content, graph_def)
  # name='' keeps the original node names instead of adding an
  # "import/" prefix during the validating import.
  tf.import_graph_def(graph_def, name='')
  return tf.train.write_graph(
      graph_def, output_dir, output_name, as_text=False)


def graphdef_to_pbtxt(filename,
                      output_dir='pbtxt/',
                      output_name='protobuf.pbtxt'):
  """Convert a binary GraphDef (.pb) into a human-readable text GraphDef.

  The parsed graph is also imported into the current default graph, which
  acts as a validity check: the import raises if the graph is malformed
  (e.g. it references ops not registered in this process).

  Args:
    filename: Path of the binary GraphDef to read. It is opened through
      TensorFlow's gfile, so any filesystem gfile supports can be used.
    output_dir: Directory the text graph is written to.
    output_name: File name of the text graph inside `output_dir`.

  Returns:
    The path of the written file, as returned by `tf.train.write_graph`.
  """
  # gfile.FastGFile is deprecated; gfile.GFile is its drop-in replacement.
  with gfile.GFile(filename, 'rb') as f:
    serialized = f.read()
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(serialized)
  # name='' keeps the original node names instead of adding an
  # "import/" prefix during the validating import.
  tf.import_graph_def(graph_def, name='')
  return tf.train.write_graph(
      graph_def, output_dir, output_name, as_text=True)


def main(_):
  """Entry point for `app.run`: convert --graph_file in one direction.

  With --binary (the default), the file is treated as a binary GraphDef and
  converted to text via `graphdef_to_pbtxt`; otherwise it is treated as a
  text GraphDef and converted to binary via `pbtxt_to_graphdef`.
  """
  cli_flags = flags.FLAGS
  # Pick the converter from --binary, then feed it the requested file.
  convert = graphdef_to_pbtxt if cli_flags.binary else pbtxt_to_graphdef
  convert(cli_flags.graph_file)


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  flags.DEFINE_string(
      'graph_file',
      default=None,
      help='Path of the graph file to convert: a binary GraphDef (.pb) when '
      '--binary is set, otherwise a text GraphDef (.pbtxt).')
  flags.DEFINE_bool('binary', default=True, help='binary graph file or not')
  # Without this, a missing --graph_file reaches the converters as None and
  # fails deep inside file opening instead of printing a usage error.
  flags.mark_flag_as_required('graph_file')

  app.run(main)

# The converted graph is written to the pbtxt/ directory, as protobuf.pbtxt
# (binary -> text) or protobuf.pb (text -> binary).
